package tcode

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"
)

type (
	// mapping selects how query maps a result set (see the result* constants).
	mapping uint8
	// ignored selects the column-skipping strategy used by updateByPk.
	ignored uint8
)

// Result-mapping modes used by query.
const (
	noMapping mapping = iota
	resultSingleVar
	resultStruct
	resultSlice
	resultRawTable
)

// Update strategies. The values 5 and 6 are kept from the original shared iota
// block, so the zero value of ignored matches neither strategy (updateByPk rejects it).
const (
	// notIgnoredEveryColumn ignores no column: every column is written.
	notIgnoredEveryColumn ignored = iota + 5
	// ignoredEveryEmptyColumn skips every empty column.
	ignoredEveryEmptyColumn
)

// TableSlice is a slice of values of one Table type.
type TableSlice[T Table] []T

// rawTableSqlInfo is the subset of *sqlInfo operations exposed by NewSqlInfo for
// ad-hoc SQL whose rows are not mapped onto a Table type.
type rawTableSqlInfo interface {
	// ToVar scans a single result row into the given pointers.
	ToVar(dest ...any) error
	// ToRawTable collects all result rows, optionally paginated.
	ToRawTable(p *page) (*rawTable, error)
	// Exec runs the statement as a write.
	Exec() (rowsAffected int64, lastInsertId int64, err error)
	// String returns the final SQL text.
	String() (sqlStr string, err error)
}
// sqlInfo accumulates one SQL statement with its bound parameters and holds the
// destinations its result is mapped into. Instances are created by SqlScript (and the
// Select / NewSqlInfo helpers built on it), which also sets the defaults below.
type sqlInfo[T Table] struct {
	// ctx carries the DB connection, optional transaction and config read by
	// GetContextDBConn, GetContextTxConn and getContextConf.
	ctx          context.Context
	// statement is the SQL text built so far by AppendSQL (rewritten by handlerSqlIn).
	statement    strings.Builder
	// params are the values bound to the '?' placeholders of statement, in order.
	params       []any
	// result is the Table value written by insert/update/delete or filled by ToStruct;
	// its methods also supply the table name and fresh instances.
	result       T
	// resultSlice receives one T per row for ToSlice.
	resultSlice  *[]T
	// rawTable receives the rows for ToRawTable / NewSqlInfo.
	rawTable     *rawTable
	// ptrVar holds the scan destinations for the current row.
	ptrVar       []any
	// mapping selects how query maps the result set.
	mapping      mapping
	// LastInsertId and RowsAffected are the driver-reported values stored by Exec.
	LastInsertId int64
	RowsAffected int64
	// updateMechanism is the update strategy used by updateByPk. SqlScript defaults it to
	// ignoredEveryEmptyColumn (skip every empty column); notIgnoredEveryColumn skips no column.
	updateMechanism ignored
	// possibleQueryEmptyColumn should be set to true when a queried column may contain NULL;
	// scanning is less efficient when it is true.
	// It is recommended that database columns are not NULL but have a proper default value.
	possibleQueryEmptyColumn bool
}

// ToVar runs the query and scans the single result row into the given pointers, one
// per selected column, in order. If no row matches, the pointers are left untouched
// and nil is returned; more than one row is an error. A nil entry in ptrVar is
// rejected before the query runs.
func (sqlInfo *sqlInfo[T]) ToVar(ptrVar ...any) (err error) {
	for i := range ptrVar {
		if ptrVar[i] == nil {
			err = fmt.Errorf("err: ptrVar index[%d] is nil", i)
			FuncLog(err)
			return err
		}
	}
	// Replace rather than append: destinations left over from an earlier query on the
	// same sqlInfo (e.g. ToStruct) would otherwise be scanned as extra columns.
	// The copy keeps sqlInfo from aliasing the caller's variadic slice.
	sqlInfo.ptrVar = append([]any(nil), ptrVar...)
	sqlInfo.mapping = resultSingleVar
	if err = sqlInfo.query(); err != nil {
		FuncLog(err)
		return err
	}
	return nil
}

// ToStruct runs the query and maps at most one row onto a new instance obtained from
// NewInstance. If no row matches, that instance is returned unchanged with a nil
// error; more than one row is an error. On error the zero T is returned.
func (sqlInfo *sqlInfo[T]) ToStruct() (t T, err error) {
	sqlInfo.mapping = resultStruct
	sqlInfo.result = sqlInfo.result.NewInstance().(T)
	if err = sqlInfo.query(); err != nil {
		return t, err
	}
	return sqlInfo.result, nil
}

// ToSlice runs the query and returns one T per row.
//
// When page is non-nil, pagination is applied first (see applyPage). If the total is
// zero (or negative), an empty non-nil slice is returned together with a "no data"
// error. Any other error returns a nil slice.
func (sqlInfo *sqlInfo[T]) ToSlice(page *page) (*[]T, error) {
	sqlInfo.resultSlice = new([]T)
	sqlInfo.mapping = resultSlice
	if page != nil {
		hasData, err := sqlInfo.applyPage(page)
		if err != nil {
			return nil, err
		}
		if !hasData {
			return &[]T{}, errors.New("no data")
		}
	}
	if err := sqlInfo.query(); err != nil {
		return nil, err
	}
	return sqlInfo.resultSlice, nil
}

// ToRawTable runs the query and collects every row into a rawTable.
//
// Pagination behaves as in ToSlice ("no data" error when the total is not positive).
// The rawTable is created on demand when this sqlInfo was not built by NewSqlInfo
// (e.g. SqlScript or Select), so query never writes through a nil rawTable. The
// rawTable is returned on every path, including errors.
func (sqlInfo *sqlInfo[T]) ToRawTable(page *page) (*rawTable, error) {
	sqlInfo.mapping = resultRawTable
	if sqlInfo.rawTable == nil {
		sqlInfo.rawTable = &rawTable{}
	}
	if page != nil {
		hasData, err := sqlInfo.applyPage(page)
		if err != nil {
			return sqlInfo.rawTable, err
		}
		if !hasData {
			return sqlInfo.rawTable, errors.New("no data")
		}
	}
	err := sqlInfo.query()
	return sqlInfo.rawTable, err
}

// applyPage obtains the total row count for the current statement and records it with
// page.setTotalCount. The count comes from page.FuncCustomTotal if set; otherwise the
// statement is wrapped in SELECT COUNT(*). When the total is positive, it appends
// LIMIT/OFFSET for page.CurrentPage (1-based) of size page.PageSize.
// It reports false with a nil error when there is nothing to fetch.
func (sqlInfo *sqlInfo[T]) applyPage(page *page) (hasData bool, err error) {
	var total int
	if page.FuncCustomTotal == nil {
		var sqlStr string
		sqlStr, err = sqlInfo.String()
		if err != nil {
			FuncLog(err)
			return false, err
		}
		info := NewSqlInfo(sqlInfo.ctx, fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS count_query", sqlStr), sqlInfo.params...)
		err = info.ToVar(&total)
	} else {
		err = page.FuncCustomTotal(&total)
	}
	if err != nil {
		FuncLog(err)
		return false, err
	}
	page.setTotalCount(total)
	if total <= 0 {
		return false, nil
	}
	sqlInfo.AppendSQL(fmt.Sprintf(" LIMIT %d OFFSET %d ", page.PageSize, (page.CurrentPage-1)*page.PageSize))
	return true, nil
}

// String finalises the statement and returns the SQL that will be executed.
//
// It first expands IN-clause slice parameters via handlerSqlIn, which rewrites
// sqlInfo.statement and sqlInfo.params in place. If the SQL contains a single quote,
// a "warn: ..." error is logged and also returned alongside the SQL. Callers such as
// query and Exec treat that as a failure, which effectively rejects inline string
// literals. NOTE(review): confirm this is intended policy, not a log-only warning.
// With DebugSQL enabled for the context, the number of '?' must equal the number of
// parameters (otherwise an error is returned). A rendering with the parameters
// inlined is then logged.
func (sqlInfo *sqlInfo[T]) String() (sqlStr string, err error) {
	if err = handlerSqlIn(sqlInfo); err != nil {
		return "", err
	}
	sqlStr = sqlInfo.statement.String()

	if strings.Contains(sqlStr, "'") {
		err = errors.New("warn: sqlStr statement contains [']")
		FuncLog(err)
	}
	if !getContextConf(sqlInfo.ctx).DebugSQL {
		return sqlStr, err
	}

	pieces := strings.Split(sqlStr, "?")
	placeholders := len(pieces) - 1
	if placeholders != len(sqlInfo.params) {
		mismatch := fmt.Errorf("err: parameter conditions do not match [?]:%d; value:%d", placeholders, len(sqlInfo.params))
		FuncLog(mismatch)
		return "", mismatch
	}
	var rendered strings.Builder
	for i, p := range sqlInfo.params {
		rendered.WriteString(pieces[i])
		rendered.WriteString(singleQ + ConvertToString(p) + singleQ)
	}
	rendered.WriteString(pieces[placeholders])
	FuncLog("debug sqlStr:", rendered.String())
	return sqlStr, err
}

// AppendSQL appends statement to the SQL being built and param to the bound
// parameters, returning sqlInfo so calls can be chained.
func (sqlInfo *sqlInfo[T]) AppendSQL(statement string, param ...any) *sqlInfo[T] {
	sqlInfo.params = append(sqlInfo.params, param...)
	sqlInfo.statement.WriteString(statement)
	return sqlInfo
}

// handlerSqlIn expands slice parameters bound to IN clauses into one placeholder per
// element. It then rewrites the IN clauses that still match findInReg as "=?".
//
// IN clauses are located with findInReg. A '?' strictly inside the a-th match is the
// placeholder for that clause: its parameter is converted with ConvertToStringSlice
// and replaced by the elements. Every other '?' keeps its parameter. Parameters after
// the last '?' are preserved. sqlInfo.statement and sqlInfo.params are rewritten in
// place. Repeated calls are intended to leave already-processed clauses alone.
// NOTE(review): that depends on the findInReg pattern — confirm.
//
// It returns an error when an IN slice is empty. It also returns an error when the
// statement has more '?' than bound parameters, instead of panicking on the index.
func handlerSqlIn[T Table](sqlInfo *sqlInfo[T]) error {
	pSql := []byte(sqlInfo.statement.String())
	// Locate every IN clause.
	index := findInReg.FindAllIndex(pSql, -1)
	ii := len(index)
	if ii == 0 {
		return nil
	}
	resultSql := make([]byte, 0, len(pSql))
	resultParam := make([]any, 0, len(sqlInfo.params))
	a := 0 // next IN-clause match not yet expanded
	b := 0 // index of the parameter bound to the next '?'
	for i := range pSql {
		if pSql[i] != '?' {
			resultSql = append(resultSql, pSql[i])
			continue
		}
		// More placeholders than parameters: report it instead of indexing out of range.
		if b >= len(sqlInfo.params) {
			err := fmt.Errorf("err: parameter conditions do not match, no value for [?] #%d; value:%d", b+1, len(sqlInfo.params))
			FuncLog(err)
			return err
		}
		if a < ii && i > index[a][0] && i < index[a][1] {
			// Placeholder inside the a-th IN clause: expand its slice parameter.
			slice := ConvertToStringSlice(sqlInfo.params[b])
			if len(slice) == 0 {
				err := errors.New("slice len is 0")
				FuncLog(err)
				return err
			}
			for j := range slice {
				if j > 0 {
					resultSql = append(resultSql, ',')
				}
				resultSql = append(resultSql, '?')
				resultParam = append(resultParam, slice[j])
			}
			a++
		} else {
			resultSql = append(resultSql, '?')
			resultParam = append(resultParam, sqlInfo.params[b])
		}
		b++
	}
	// Rewrite the IN clauses findInReg still matches as "=?". Per the original note,
	// these are the single-element clauses. The rest of the match is padded with
	// spaces so the byte length is unchanged.
	index = findInReg.FindAllIndex(resultSql, -1)
	for i := range index {
		for j := index[i][0]; j < index[i][1]; j++ {
			resultSql[j] = ' '
		}
		resultSql[index[i][0]] = '='
		resultSql[index[i][0]+1] = '?'
	}
	// Keep parameters that have no placeholder of their own (b <= len(params) here).
	resultParam = append(resultParam, sqlInfo.params[b:]...)
	sqlInfo.statement.Reset()
	sqlInfo.statement.Write(resultSql)
	sqlInfo.params = resultParam
	return nil
}

// PossibleQueryEmptyColumn marks the query as possibly returning NULL columns. Rows are
// then scanned with reTryScan instead of rows.Scan, which is less efficient.
// Prefer non-nullable database columns with proper default values.
// It returns sqlInfo for chaining.
func (sqlInfo *sqlInfo[T]) PossibleQueryEmptyColumn() *sqlInfo[T] {
	info := sqlInfo
	info.possibleQueryEmptyColumn = true
	return info
}

// query executes the statement built so far as a read and maps the rows according to
// sqlInfo.mapping:
//   - resultSingleVar: at most one row is scanned into sqlInfo.ptrVar;
//   - resultStruct: at most one row is scanned into sqlInfo.result's column containers;
//   - resultSlice: every row is scanned and a new T (from NewTable) is appended to
//     *sqlInfo.resultSlice;
//   - resultRawTable: every row is scanned into sqlInfo.rawTable.
//
// In the single-row modes, no matching row leaves the destinations untouched and is
// not an error; a second row is an error. The query runs on the context's transaction
// when there is one, otherwise on the context's DB connection. rows.Err is checked
// after iteration, so a result set cut short by a driver or context error is reported
// rather than returned as a partial success.
func (sqlInfo *sqlInfo[T]) query() error {
	sqlStr, err := sqlInfo.String()
	if err != nil {
		FuncLog(err)
		return err
	}
	var rows *sql.Rows
	tx := GetContextTxConn(sqlInfo.ctx)
	start := time.Now()
	if tx != nil {
		rows, err = tx.QueryContext(sqlInfo.ctx, sqlStr, sqlInfo.params...)
	} else {
		db := GetContextDBConn(sqlInfo.ctx)
		if db == nil {
			err = errors.New("please set ctx dbConn")
			FuncLog(err)
			return err
		}
		rows, err = db.QueryContext(sqlInfo.ctx, sqlStr, sqlInfo.params...)
	}
	// time.Since uses the monotonic clock, so wall-clock adjustments cannot skew the log.
	FuncSQLLog(float64(time.Since(start))/1e6, sqlStr, sqlInfo.params)
	if err != nil {
		FuncLog(err)
		return err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			FuncLog(closeErr)
		}
	}()
	columns, err := rows.Columns()
	if err != nil {
		FuncLog(err)
		return err
	}
	switch sqlInfo.mapping {
	case noMapping:
		FuncLog("warn: no mapping")
		return nil
	case resultSingleVar:
		if err = sqlInfo.scanAtMostOneRow(rows); err != nil {
			return err
		}
	case resultStruct:
		sqlInfo.ptrVar = sqlInfo.result.RawColumnContainer(columns...)
		if err = sqlInfo.scanAtMostOneRow(rows); err != nil {
			return err
		}
	case resultSlice:
		// empty holds the default field values; it is copied back into row after
		// each scan so the previous row's values do not leak into the next one.
		empty := sqlInfo.result.NewInstance()
		row := sqlInfo.result.NewInstance()
		sqlInfo.ptrVar = row.RawColumnContainer(columns...)
		for rows.Next() {
			if err = sqlInfo.scanCurrentRow(rows); err != nil {
				FuncLog(err)
				return err
			}
			*sqlInfo.resultSlice = append(*sqlInfo.resultSlice, row.NewTable().(T))
			row.CopyFrom(empty)
		}
	case resultRawTable:
		// Only NewSqlInfo / ToRawTable initialise rawTable; never write through nil.
		if sqlInfo.rawTable == nil {
			sqlInfo.rawTable = &rawTable{}
		}
		for rows.Next() {
			sqlInfo.ptrVar = sqlInfo.rawTable.RawColumnContainer(columns...)
			if err = rows.Scan(sqlInfo.ptrVar...); err != nil {
				FuncLog(err)
				return err
			}
		}
		sqlInfo.rawTable.columnNames = columns
	default:
		err = errors.New("err: unknown mapping type")
		FuncLog(err)
		return err
	}
	if err = rows.Err(); err != nil {
		FuncLog(err)
		return err
	}
	return nil
}

// scanCurrentRow scans the current row into sqlInfo.ptrVar, using reTryScan when
// PossibleQueryEmptyColumn was requested (columns that may hold NULL).
func (sqlInfo *sqlInfo[T]) scanCurrentRow(rows *sql.Rows) error {
	if sqlInfo.possibleQueryEmptyColumn {
		return reTryScan(sqlInfo.ptrVar, rows)
	}
	return rows.Scan(sqlInfo.ptrVar...)
}

// scanAtMostOneRow scans the first row, if any, into sqlInfo.ptrVar and returns an
// error when the result set contains a second row.
func (sqlInfo *sqlInfo[T]) scanAtMostOneRow(rows *sql.Rows) error {
	if rows.Next() {
		if err := sqlInfo.scanCurrentRow(rows); err != nil {
			FuncLog(err)
			return err
		}
	}
	if rows.Next() {
		err := errors.New("err: require one row data,but returns multi row")
		FuncLog(err)
		return err
	}
	return nil
}

// insert inserts data as a single row using "INSERT INTO <table> VALUES (?,...)".
// There is one placeholder per entry of data.Columns(), bound in that order with
// ConvertToString.
// Field defaults follow Go's zero values (int=0, string="", time="0000-01-01 00:00:00", ...).
// Non-auto-increment primary keys must be set by the caller, e.g. with tcode.FuncGenId.
//
// When the driver reports a non-zero LastInsertId and the primary-key column is one of
// data.Columns(), the id is written back through that column's container. This applies
// if the container is a pointer to an integer type (int, int8..int64, uint,
// uint8..uint64). Presumably the containers point into data's fields — confirm
// against RawColumnContainer.
func (sqlInfo *sqlInfo[T]) insert(data T) (rowsAffected int64, lastInsertId int64, err error) {
	sqlInfo.result = data
	columns := data.Columns()
	container := data.RawColumnContainer(columns...)
	sqlInfo.AppendSQL("INSERT INTO " + data.TableName() + " VALUES (")
	for i := range container {
		if i > 0 {
			sqlInfo.AppendSQL(",")
		}
		sqlInfo.AppendSQL("?", ConvertToString(container[i]))
	}
	sqlInfo.AppendSQL(")")
	rowsAffected, lastInsertId, err = sqlInfo.Exec()
	if err != nil || lastInsertId == 0 {
		return rowsAffected, lastInsertId, err
	}
	index := StringInIndex(GetPkColumnName(sqlInfo.ctx), columns)
	if index <= -1 {
		return rowsAffected, lastInsertId, nil
	}
	// Echo the auto-increment key back into the struct. *int and *uint are the most
	// common id field types and are handled alongside the sized integers.
	switch p := container[index].(type) {
	case *int:
		*p = int(lastInsertId)
	case *int8:
		*p = int8(lastInsertId)
	case *int16:
		*p = int16(lastInsertId)
	case *int32:
		*p = int32(lastInsertId)
	case *int64:
		*p = lastInsertId
	case *uint:
		*p = uint(lastInsertId)
	case *uint8:
		*p = uint8(lastInsertId)
	case *uint16:
		*p = uint16(lastInsertId)
	case *uint32:
		*p = uint32(lastInsertId)
	case *uint64:
		*p = uint64(lastInsertId)
	}
	return rowsAffected, lastInsertId, nil
}
// insertBatch inserts every element of *datas with one multi-row statement of the
// form "INSERT INTO <table> VALUES (...),(...)". The table name is taken from the
// first element. Each element contributes one placeholder per entry of its Columns(),
// bound with ConvertToString. A nil pointer or an empty batch is a no-op that returns
// zeros and a nil error.
// NOTE(review): very large batches may exceed the driver's placeholder limit.
func (sqlInfo *sqlInfo[T]) insertBatch(datas *[]T) (rowsAffected int64, lastInsertId int64, err error) {
	if datas == nil || len(*datas) == 0 {
		return 0, 0, nil
	}
	ds := *datas
	sqlInfo.result = ds[0]
	sqlInfo.AppendSQL("INSERT INTO " + ds[0].TableName() + " VALUES ")
	for i := range ds {
		if i > 0 {
			sqlInfo.AppendSQL(",")
		}
		sqlInfo.AppendSQL("(")
		container := ds[i].RawColumnContainer(ds[i].Columns()...)
		for j := range container {
			if j > 0 {
				sqlInfo.AppendSQL(",")
			}
			sqlInfo.AppendSQL("?", ConvertToString(container[j]))
		}
		sqlInfo.AppendSQL(")")
	}
	return sqlInfo.Exec()
}

// updateByPk builds and executes "UPDATE <table> SET ... WHERE <pk>=?" for data.
//
// The SET list depends on sqlInfo.updateMechanism:
//   - notIgnoredEveryColumn: every column of data.Columns() is written, including the
//     primary-key column itself.
//   - ignoredEveryEmptyColumn (the SqlScript default): the primary-key column is skipped.
//     So is every column whose ConvertToString value equals that of a fresh
//     data.NewInstance(). An error is returned, before anything is executed, when the
//     primary key is empty or when no column is left to update.
//
// Any other mechanism returns an error without touching the database.
func (sqlInfo *sqlInfo[T]) updateByPk(data T) (rowsAffected int64, lastInsertId int64, err error) {
	sqlInfo.result = data
	sqlInfo.AppendSQL("UPDATE " + data.TableName() + " SET ")
	columns := data.Columns()
	container := data.RawColumnContainer(columns...)
	pkColumnName := GetPkColumnName(sqlInfo.ctx)
	switch sqlInfo.updateMechanism {
	case notIgnoredEveryColumn: // write every column
		for i := range columns {
			if i > 0 {
				sqlInfo.AppendSQL(",")
			}
			sqlInfo.AppendSQL(columns[i]+"=?", ConvertToString(container[i]))
		}
		condition := data.RawColumnContainer(pkColumnName)
		sqlInfo.AppendSQL(" WHERE "+pkColumnName+"=?", ConvertToString(condition[0]))
		return sqlInfo.Exec()
	case ignoredEveryEmptyColumn: // skip every empty column
		empty := data.NewInstance()
		// Validate the condition column first so an unset key never reaches the database.
		pkValue := ConvertToString(data.RawColumnContainer(pkColumnName)[0])
		if pkValue == ConvertToString(empty.RawColumnContainer(pkColumnName)[0]) {
			err = fmt.Errorf("condition '%s' column value is empty", pkColumnName)
			FuncLog(err)
			return rowsAffected, lastInsertId, err
		}
		emptyEqualiser := empty.RawColumnContainer(columns...)
		k := 0 // number of columns written to the SET list
		for i := range columns {
			if columns[i] == pkColumnName { // the condition column is never SET
				continue
			}
			value := ConvertToString(container[i])
			if value == ConvertToString(emptyEqualiser[i]) {
				continue
			}
			if k > 0 {
				sqlInfo.AppendSQL(",")
			}
			sqlInfo.AppendSQL(columns[i]+"=?", value)
			k++
		}
		// An empty SET list would produce invalid SQL ("SET  WHERE ..."); report it directly.
		if k == 0 {
			err = fmt.Errorf("no non-empty column to update in %s", data.TableName())
			FuncLog(err)
			return rowsAffected, lastInsertId, err
		}
		sqlInfo.AppendSQL(" WHERE "+pkColumnName+"=?", pkValue)
		return sqlInfo.Exec()
	default:
		err = errors.New("no action was taken")
		FuncLog(err)
		return rowsAffected, lastInsertId, err
	}
}

// deleteByPk executes "DELETE FROM <table> WHERE <pk>=?" using data's primary-key value.
// Like updateByPk, it refuses to run when the key equals the value of a fresh
// data.NewInstance(), so an unset key never reaches the database. The key is bound as
// its ConvertToString rendering, matching insert/updateByPk.
func (sqlInfo *sqlInfo[T]) deleteByPk(data T) (rowsAffected int64, lastInsertId int64, err error) {
	sqlInfo.result = data
	pkColumnName := GetPkColumnName(sqlInfo.ctx)
	pkValue := ConvertToString(data.RawColumnContainer(pkColumnName)[0])
	if pkValue == ConvertToString(data.NewInstance().RawColumnContainer(pkColumnName)[0]) {
		err = fmt.Errorf("condition '%s' column value is empty", pkColumnName)
		FuncLog(err)
		return rowsAffected, lastInsertId, err
	}
	sqlInfo.AppendSQL("DELETE FROM "+data.TableName()+" WHERE "+pkColumnName+"=?", pkValue)
	return sqlInfo.Exec()
}

// Exec runs the built statement as a write. It returns the driver-reported number of
// affected rows and the last insert id; both are also stored on sqlInfo.RowsAffected
// and sqlInfo.LastInsertId.
//
// The statement runs on the transaction carried by ctx when there is one. Otherwise it
// runs directly on conf.DB if the context config sets SkipDefaultTransaction; by
// default it is wrapped in its own Transaction. When LastInsertId fails, the
// affected-row count is still returned alongside the error.
func (sqlInfo *sqlInfo[T]) Exec() (rowsAffected int64, lastInsertId int64, err error) {
	var sqlStr string
	if sqlStr, err = sqlInfo.String(); err != nil {
		FuncLog(err)
		return 0, 0, err
	}

	var result sql.Result
	tx := GetContextTxConn(sqlInfo.ctx)
	begin := time.Now().UnixNano()
	if tx != nil {
		result, err = tx.ExecContext(sqlInfo.ctx, sqlStr, sqlInfo.params...)
	} else if conf := getContextConf(sqlInfo.ctx); conf.SkipDefaultTransaction {
		if conf.DB == nil {
			err = errors.New("please set ctx dbConn")
			FuncLog(err)
			return 0, 0, err
		}
		result, err = conf.DB.ExecContext(sqlInfo.ctx, sqlStr, sqlInfo.params...)
	} else {
		err = Transaction(sqlInfo.ctx, func(txCtx context.Context) error {
			var execErr error
			result, execErr = GetContextTxConn(txCtx).ExecContext(txCtx, sqlStr, sqlInfo.params...)
			return execErr
		})
	}
	FuncSQLLog(float64(time.Now().UnixNano()-begin)/1e6, sqlStr, sqlInfo.params)
	if err != nil {
		FuncLog(err)
		return 0, 0, err
	}

	if sqlInfo.RowsAffected, err = result.RowsAffected(); err != nil {
		FuncLog(err)
		return 0, 0, err
	}
	if sqlInfo.LastInsertId, err = result.LastInsertId(); err != nil {
		FuncLog(err)
		return sqlInfo.RowsAffected, 0, err
	}
	return sqlInfo.RowsAffected, sqlInfo.LastInsertId, nil
}

// SqlScript starts a statement builder for T with the given initial SQL and
// parameters. No result mapping is selected, updates skip empty columns, and NULL
// tolerance is off.
func SqlScript[T Table](ctx context.Context, statement string, param ...any) *sqlInfo[T] {
	info := &sqlInfo[T]{
		ctx:                      ctx,
		mapping:                  noMapping,
		updateMechanism:          ignoredEveryEmptyColumn,
		possibleQueryEmptyColumn: false,
	}
	return info.AppendSQL(statement, param...)
}

// Select starts "SELECT <columns> FROM <table> " for T's table. The columns are joined
// with ","; "*" is used when none are given. Further clauses are added with AppendSQL.
// The default is kept local instead of appending to columns. Appending to a
// zero-length variadic slice could otherwise write into the caller's backing array.
func Select[T Table](ctx context.Context, columns ...string) *sqlInfo[T] {
	selected := "*"
	if len(columns) > 0 {
		selected = strings.Join(columns, ",")
	}
	sqlInfo := SqlScript[T](ctx, "")
	// sqlInfo.result is T's zero value here; TableName is assumed not to depend on field values.
	return sqlInfo.AppendSQL("SELECT " + selected + " FROM " + sqlInfo.result.TableName() + " ")
}

// NewSqlInfo builds an ad-hoc statement whose rows are collected into a rawTable
// rather than mapped onto a Table type.
func NewSqlInfo(ctx context.Context, statement string, param ...any) rawTableSqlInfo {
	s := SqlScript[*rawTable](ctx, statement, param...)
	s.mapping = resultRawTable
	s.rawTable = new(rawTable)
	return s
}

// Insert inserts t as one row. Field defaults follow Go's zero values
// (int=0, string="", time="0000-01-01 00:00:00", ...).
// Non-auto-increment primary keys must be set by the caller, e.g. with tcode.FuncGenId.
func Insert(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	rowsAffected, lastInsertId, err = script.insert(t)
	return rowsAffected, lastInsertId, err
}

// InsertBatch inserts every table in *ts with a single multi-row INSERT statement.
func InsertBatch(ctx context.Context, ts *[]Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	rowsAffected, lastInsertId, err = script.insertBatch(ts)
	return rowsAffected, lastInsertId, err
}

// Save updates t by primary key (skipping empty columns) when Exist reports a row with
// its key, and inserts it otherwise.
func Save(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	exist, err := Exist(ctx, t)
	switch {
	case err != nil:
		return 0, 0, err
	case exist:
		return UpdateByPk(ctx, t)
	default:
		return Insert(ctx, t)
	}
}

// Exist reports whether a row with t's primary-key value exists in t's table.
// A key whose ConvertToString rendering is "" or "0" is treated as unset and reported
// as non-existent without querying.
func Exist(ctx context.Context, t Table) (bool, error) {
	pk := GetPkColumnName(ctx)
	pkVal := ConvertToString(t.RawColumnContainer(pk)[0])
	if pkVal == "" || pkVal == "0" {
		return false, nil
	}
	var count int
	stmt := fmt.Sprintf("SELECT count(%s) FROM %s WHERE %s=?", pk, t.TableName(), pk)
	if err := NewSqlInfo(ctx, stmt, pkVal).ToVar(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}

// UpdateByPk updates t by primary key only, skipping every empty column.
func UpdateByPk(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	rowsAffected, lastInsertId, err = script.updateByPk(t)
	return rowsAffected, lastInsertId, err
}

// UpdateNotIgnoredEveryColumnByPk updates t by primary key only, writing every column
// (no column is skipped, even when empty).
func UpdateNotIgnoredEveryColumnByPk(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	script.updateMechanism = notIgnoredEveryColumn
	rowsAffected, lastInsertId, err = script.updateByPk(t)
	return rowsAffected, lastInsertId, err
}

// DeleteByPk deletes the row matching t's primary key only.
func DeleteByPk(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	rowsAffected, lastInsertId, err = script.deleteByPk(t)
	return rowsAffected, lastInsertId, err
}

// Transaction runs call inside a new transaction. The transaction is started on the
// context's DB connection with the context's TxOptions. The context passed to call
// carries the transaction (via WithTX), so query/Exec issued with it use the
// transaction.
//
// The transaction is committed when call returns nil. It is rolled back when call
// returns an error or panics; a panic is recovered and converted into an error. The
// returned error is one of:
//   - call's error;
//   - the commit error;
//   - an error describing the panic.
// Rollback failures are only logged, so they never mask call's error.
func Transaction(ctx context.Context, call func(ctx context.Context) error) (err error) {
	db := GetContextDBConn(ctx)
	if db == nil {
		err = errors.New("please set ctx dbConn")
		FuncLog(err)
		return err
	}
	tx, err := db.BeginTx(ctx, GetContextTxOptions(ctx))
	if err != nil {
		return err
	}
	// err is the named result, so the assignments below reach the caller: a failed
	// commit or a recovered panic is reported instead of silently returning nil.
	defer func() {
		if p := recover(); p != nil {
			FuncLog(p)
			if rbErr := tx.Rollback(); rbErr != nil {
				FuncLog(rbErr)
			}
			err = fmt.Errorf("transaction rolled back after panic: %v", p)
			return
		}
		if err != nil {
			FuncLog(err)
			if rbErr := tx.Rollback(); rbErr != nil {
				FuncLog(rbErr)
			}
			return
		}
		if err = tx.Commit(); err != nil {
			FuncLog(err)
		}
	}()
	return call(WithTX(ctx, tx))
}
